MXNet model zoo 多网络推理

预备工作

  • 从 MXNet Model Zoo 下载所需模型的参数文件（`<前缀>-0000.params`）和网络结构文件（`<前缀>-symbol.json`）（链接）；程序通过 `mx.model.load_checkpoint(前缀, 0)` 加载，因此文件名须符合该命名格式
  • 下载 synset.txt（ImageNet 类别标签文件，每行一个类别），程序用它把网络输出的类别索引映射为可读的类别名称

源码实现

import mxnet as mx
import numpy as np
import cv2,sys,time
from collections import namedtuple
def loadModel(modelname, epoch=0, ctx=None, data_shape=(1, 3, 224, 224)):
    """Load an MXNet checkpoint and bind it as an inference-only Module.

    Args:
        modelname: checkpoint prefix passed to ``mx.model.load_checkpoint``
            (the files ``<prefix>-symbol.json`` / ``<prefix>-NNNN.params``).
        epoch: checkpoint epoch number to load (default 0, as before).
        ctx: MXNet context to bind on; ``None`` means ``mx.gpu()`` (the
            previous hard-coded behaviour). Pass ``mx.cpu()`` on machines
            without a GPU.
        data_shape: shape bound to the ``'data'`` input, default
            ``(1, 3, 224, 224)``.

    Returns:
        The bound ``mx.mod.Module`` with parameters set.
    """
    if ctx is None:
        ctx = mx.gpu()
    # Only the checkpoint read is timed; bind/set_params are not included.
    t1 = time.perf_counter()
    sym, arg_params, aux_params = mx.model.load_checkpoint(modelname, epoch)
    t2 = time.perf_counter()
    print("Loaded in %2.2f milliseconds" % (1000 * (t2 - t1)))
    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', data_shape)])
    # allow_missing=True: with label_names=None, arguments absent from the
    # checkpoint (presumably the softmax label input) are tolerated instead
    # of raising — NOTE(review): confirm per model.
    mod.set_params(arg_params, aux_params, allow_missing=True)
    return mod

def loadCategories(path='../picture/synset.txt'):
    """Read the category label file, one label per line.

    Args:
        path: label file to read (defaults to the original hard-coded path).

    Returns:
        List of labels with trailing whitespace/newlines stripped, in file
        order, so that index ``i`` is the label for output class ``i``.
    """
    # `with` guarantees the file handle is closed (the original leaked it).
    with open(path, 'r', encoding='utf-8') as synsetfile:
        return [line.rstrip() for line in synsetfile]

def prepareNDArray(filename, size=(224, 224)):
    """Read an image file and convert it to a 1 x 3 x H x W MXNet NDArray.

    The image is converted from OpenCV's BGR order to RGB, resized and
    reordered from HWC to CHW. No mean subtraction or scaling is applied.

    Args:
        filename: path of the image to read.
        size: target ``(width, height)`` passed to ``cv2.resize``; the
            default matches the ``(1, 3, 224, 224)`` shape used by loadModel.

    Returns:
        ``mx.nd.array`` of shape ``(1, 3, height, width)``.

    Raises:
        OSError: if OpenCV cannot read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError("Cannot read image file: %s" % filename)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, tuple(size))
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    img = img[np.newaxis, :]      # add batch axis -> 1 x C x H x W
    return mx.nd.array(img)


def predict(filename, model, categories, n):
    """Classify one image and return the ``n`` most probable categories.

    Args:
        filename: image path, preprocessed by prepareNDArray.
        model: a bound MXNet Module (see loadModel).
        categories: labels indexed by output class (see loadCategories).
        n: number of top results to return (negative values give none).

    Returns:
        List of ``(probability, label)`` tuples, most probable first.
    """
    array = prepareNDArray(filename)
    Batch = namedtuple('Batch', ['data'])
    t1 = time.perf_counter()
    model.forward(Batch([array]))
    output = model.get_outputs()[0]
    # MXNet runs asynchronously: forward() only enqueues the computation, so
    # wait for the result before stopping the clock, or the timing is bogus.
    output.wait_to_read()
    t2 = time.perf_counter()
    print("Predicted in %2.2f millsecond" % (1000 * (t2 - t1)))
    prob = np.squeeze(output.asnumpy())
    best = np.argsort(prob)[::-1][:max(n, 0)]
    return [(prob[i], categories[i]) for i in best]

def init(modelname):
    """Return ``(module, categories)``: the bound model for ``modelname`` and the label list."""
    return loadModel(modelname), loadCategories()

实验结果

# Image to classify: the first command-line argument, or the sample image
# when no argument is given.
filename = sys.argv[1] if len(sys.argv) > 1 else '../picture/cat.jpg'

# NOTE: this checkpoint is Inception-BN, not Inception v3; the printed label
# is kept unchanged to match previously recorded output.
print("*** Inception v3")
inceptionv3, c = init("../model/Inception-BN/Inception-BN")
print(predict(filename, inceptionv3, c, 1))

print("*** squeezenet_v1.0")
squeeze_v1_0, c = init("../model/squeezenet/squeezenet_v1.0")
print(predict(filename, squeeze_v1_0, c, 1))

print("*** squeezenet_v1.1")
squeeze_v1_1, c = init("../model/squeezenet/squeezenet_v1.1")
print(predict(filename, squeeze_v1_1, c, 1))

print("*** nin")
nin, c = init("../model/nin/nin")
print(predict(filename, nin, c, 1))
*** Inception v3
Loaded in 20.42 milliseconds
Predicted in 0.19 millsecond
[(0.34680301, 'n02112018 Pomeranian')]
*** squeezenet_v1.0
Loaded in 2.94 milliseconds
Predicted in 0.08 millsecond
[(0.22845435, 'n02326432 hare')]
*** squeezenet_v1.1
Loaded in 2.93 milliseconds
Predicted in 0.08 millsecond
[(0.71776724, 'n02123045 tabby, tabby cat')]
*** nin
Loaded in 5.85 milliseconds
Predicted in 0.14 millsecond
[(0.67466462, 'n02119022 red fox, Vulpes vulpes')]
-------------本文结束 感谢您的阅读-------------
0%